torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed \
--log_dir logs/imagenet/d=4096--i=1024--heads=16--sd=8--nlm=64--synch=8192-2048-32-h=64-random-pairing--iters=50x25--backbone=152x4 \
--model ctm \
--dataset imagenet \
--d_model 4096 \
--d_input 1024 \
--synapse_depth 8 \
--heads 16 \
--n_synch_out 8196 \
--n_synch_action 2048 \
--n_random_pairing_self 32 \
--neuron_select_type random-pairing \
--iterations 50 \
--memory_length 25 \
--deep_memory \
--memory_hidden_dims 64 \
--dropout 0.2 \
--dropout_nlm 0 \
--no-do_normalisation \
--positional_embedding_type none \
--backbone_type resnet152-4 \
--batch_size 64 \
--batch_size_test 64 \
--n_test_batches 200 \
--lr 5e-4 \
--gradient_clipping 20 \
--training_iterations 500001 \
--save_every 1000 \
--track_every 5000 \
--warmup_steps 10000 \
--use_scheduler \
--scheduler_type cosine \
--weight_decay 0.0 \
--seed 1 \
--use_amp \
--reload  \
--num_workers_train 8 \
--use_custom_sampler
